package netutil

import (
	"bytes"
	"io"
	"net"
	"reflect"
	"strings"
	"testing"
)

// TestParseProxyProtocolSuccess feeds well-formed PROXY protocol v2 headers to
// readProxyProto and verifies both the returned address and the bytes that
// remain unread in the reader after the header has been parsed.
func TestParseProxyProtocolSuccess(t *testing.T) {
	// The 12-byte v2 signature that starts every header below.
	signature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	// header concatenates the signature with the given parts into a freshly
	// allocated slice, so the shared fixtures below are never aliased.
	header := func(parts ...[]byte) []byte {
		msg := append([]byte(nil), signature...)
		for _, p := range parts {
			msg = append(msg, p...)
		}
		return msg
	}
	// IPv4 address block: src 127.0.0.1, dst 0.0.0.0, src port 80, dst port 0.
	ipv4Block := []byte{0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0}
	// The IPv6 loopback address ::1.
	loopback6 := []byte{
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
	}

	check := func(input, expectedTail []byte, expectedAddr net.Addr) {
		t.Helper()
		src := bytes.NewBuffer(input)
		addr, err := readProxyProto(src)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if !reflect.DeepEqual(addr, expectedAddr) {
			t.Fatalf("ip not match, got: %v, want: %v", addr, expectedAddr)
		}
		rest, err := io.ReadAll(src)
		if err != nil {
			t.Fatalf("cannot read tail: %s", err)
		}
		if !bytes.Equal(rest, expectedTail) {
			t.Fatalf("unexpected tail after parsing proxy protocol\ngot:\n%q\nwant:\n%q", rest, expectedTail)
		}
	}

	// LOCAL command (0x20): the address block is consumed, but no address is reported.
	check(header([]byte{0x20, 0x11, 0x00, 0x0C}, ipv4Block), nil, nil)
	// PROXY command over TCP/IPv4 (src ip, dst ip, src port, dst port).
	check(header([]byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block), nil,
		&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})
	// Bytes that follow the header must be left in the reader untouched.
	check(header([]byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block, ipv4Block), ipv4Block,
		&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80})
	// PROXY command over TCP/IPv6: 36-byte (0x24) block of src ip, dst ip and ports.
	check(header([]byte{0x21, 0x21, 0x00, 0x24}, loopback6, loopback6, []byte{0, 80, 0, 0}), nil,
		&net.TCPAddr{IP: net.ParseIP("::1"), Port: 80})
}

// TestParseProxyProtocolFail verifies that readProxyProto rejects truncated,
// malformed and unsupported PROXY protocol v2 headers, and that it returns a
// nil address whenever it reports an error.
func TestParseProxyProtocolFail(t *testing.T) {
	// The 12-byte v2 signature that starts every header below.
	signature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	// header concatenates the signature with the given parts into a freshly
	// allocated slice, so callers may mutate the result safely.
	header := func(parts ...[]byte) []byte {
		msg := append([]byte(nil), signature...)
		for _, p := range parts {
			msg = append(msg, p...)
		}
		return msg
	}
	// IPv4 address block: src 127.0.0.1, dst 0.0.0.0, src port 80, dst port 0.
	ipv4Block := []byte{0x7F, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 80, 0, 0}
	// The IPv6 loopback address ::1.
	loopback6 := []byte{
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
	}

	expectFailure := func(input []byte) {
		t.Helper()
		addr, err := readProxyProto(bytes.NewBuffer(input))
		switch {
		case err == nil:
			t.Fatalf("expected error at input %v", input)
		case addr != nil:
			t.Fatalf("expected ip to be nil, got: %v", addr)
		}
	}

	// Input ends in the middle of the signature.
	expectFailure(append([]byte(nil), signature[:7]...))
	// Full signature, but the header is truncated right after the version byte.
	expectFailure(header([]byte{0x21}))
	// Corrupted signature (second byte 0x1A instead of 0x0A).
	corrupted := header([]byte{0x21, 0x11, 0x00, 0x0C})
	corrupted[1] = 0x1A
	expectFailure(corrupted)
	// Version nibble 3 instead of 2.
	expectFailure(header([]byte{0x31, 0x11, 0x00, 0x0C}))
	// Declared address length (0xff0C) far exceeds the available data.
	expectFailure(header([]byte{0x21, 0x11, 0xff, 0x0C}))
	// Address block cut short: src ip, dst ip and src port only, no dst port.
	expectFailure(header([]byte{0x21, 0x11, 0x00, 0x0C}, ipv4Block[:10]))
	// Declared length of 8 is too small for a TCP/IPv4 address block.
	expectFailure(header([]byte{0x21, 0x11, 0x00, 0x08}, ipv4Block[:8]))
	// Unsupported address family nibble (0x3).
	expectFailure(header([]byte{0x21, 0x31, 0x00, 0x0C}, ipv4Block))
	// Unsupported command nibble (0x2).
	expectFailure(header([]byte{0x22, 0x11, 0x00, 0x0C}, ipv4Block))
	// TCP/IPv6 family declared, but with a 12-byte IPv4-sized address block.
	expectFailure(header([]byte{0x21, 0x21, 0x00, 0x0C}, ipv4Block))
	// UDP over IPv4 isn't supported.
	expectFailure(header([]byte{0x21, 0x12, 0x00, 0x0C}, ipv4Block))
	// UDP over IPv6 isn't supported.
	expectFailure(header([]byte{0x21, 0x22, 0x00, 0x24}, loopback6, loopback6, []byte{0, 80, 0, 0}))
}

// TestProxyProtocolConnReadWriteSuccessful sends a PROXY protocol v2 header
// followed by application data over an in-memory pipe. It checks that the
// wrapped connection strips the header, delivers the data unchanged, and
// reports the header's source address from RemoteAddr.
func TestProxyProtocolConnReadWriteSuccessful(t *testing.T) {
	server, client := net.Pipe()
	defer server.Close()
	defer client.Close()

	ppc := newProxyProtocolConn(server)

	expectedData := []byte("Hello, World!")

	// Send proxy protocol header and test data from client.
	go func() {
		proxyHeader := []byte{
			0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, // signature
			0x21,       // version 2, PROXY command
			0x11,       // TCP over IPv4
			0x00, 0x0C, // length: 12 bytes (IPv4 + ports)
			192, 168, 1, 100, // source IP
			10, 0, 0, 1, // destination IP
			0x1F, 0x90, // source port 8080
			0x00, 0x50, // destination port 80
		}

		// Write errors are ignored on purpose. net.Pipe is in-memory, and a
		// write fails only once the pipe is closed, which happens after the
		// test has already finished or failed.
		_, _ = client.Write(proxyHeader)
		_, _ = client.Write(expectedData)
	}()

	// Use io.ReadFull rather than a single Read. The io.Reader contract allows
	// Read to return fewer bytes than requested, which would make a single-Read
	// assertion flaky.
	actualData := make([]byte, len(expectedData))
	if _, err := io.ReadFull(ppc, actualData); err != nil {
		t.Fatalf("failed to read from proxy protocol connection: %v", err)
	}
	if !bytes.Equal(actualData, expectedData) {
		t.Fatalf("expected %q, got %q", expectedData, actualData)
	}

	// Verify the remote address is the source address carried in the header.
	expectedAddr := &net.TCPAddr{
		IP:   net.IPv4(192, 168, 1, 100),
		Port: 8080,
	}
	gotAddr := ppc.RemoteAddr()
	if !reflect.DeepEqual(gotAddr, expectedAddr) {
		t.Fatalf("expected remote addr %v, got %v", expectedAddr, gotAddr)
	}
}

// TestProxyProtocolConnReadWriteFailure sends a plain HTTP request line instead
// of a PROXY protocol v2 header. It checks that Read fails with an
// "unexpected proxy protocol header" error and that RemoteAddr still reports
// the underlying pipe's remote address.
func TestProxyProtocolConnReadWriteFailure(t *testing.T) {
	srv, cli := net.Pipe()
	defer srv.Close()
	defer cli.Close()

	conn := newProxyProtocolConn(srv)

	go func() {
		// The write error is ignored: net.Pipe is in-memory and only fails
		// once one of its ends has been closed.
		_, _ = cli.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
	}()

	_, err := conn.Read(make([]byte, 100))
	switch {
	case err == nil:
		t.Fatal("expected error when reading from proxy protocol connection; got none")
	case !strings.HasPrefix(err.Error(), "unexpected proxy protocol header"):
		t.Fatalf("unexpected proxy protocol header error expected; got: %v", err)
	}

	// After a header error the wrapper must fall back to the original remote address.
	if want, got := srv.RemoteAddr(), conn.RemoteAddr(); !reflect.DeepEqual(got, want) {
		t.Fatalf("expected remote addr %v, got %v", want, got)
	}
}
